#!/bin/bash

# Copyright 2019 Google Inc. All Rights Reserved. Licensed under the Apache
# License, Version 2.0 (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

# Abort on the first command that exits non-zero.
set -e

# Defaults; either may be overridden from the command line.
WORK_DIR=/tmp/cloudml-samples/molecules
MAX_DATA_FILES=5

#######################################
# Parse command line arguments.
# Globals:
#   WORK_DIR        (written) value of --work-dir
#   MAX_DATA_FILES  (written) value of --max-data-files
# Arguments:
#   The script's command line, i.e. "$@".
# Outputs:
#   Writes a diagnostic to stderr on invalid input.
# Returns:
#   Exits the shell with status 1 on an unrecognized argument or when an
#   option is given without its value.
#######################################
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --work-dir)
        # Without this check a trailing "--work-dir" would make the final
        # `shift` fail and `set -e` would end the script with no message.
        if [[ $# -lt 2 ]]; then
          echo "error: $1 requires a value" >&2
          exit 1
        fi
        WORK_DIR=$2
        shift
        ;;
      --max-data-files)
        if [[ $# -lt 2 ]]; then
          echo "error: $1 requires a value" >&2
          exit 1
        fi
        MAX_DATA_FILES=$2
        shift
        ;;
      *)
        echo "error: unrecognized argument $1" >&2
        exit 1
        ;;
    esac
    shift
  done
}

parse_args "$@"

#######################################
# Print a command line prefixed with "$ ", then execute it.
# Arguments:
#   The command to run followed by its arguments.
# Outputs:
#   One line to stdout showing the command, then whatever the command prints.
# Returns:
#   The exit status of the executed command.
#######################################
run() {
  local shown
  # Prefix every argument with a single space so the joined text matches
  # what a space-separated echo of "$ " and the arguments would print.
  printf -v shown ' %s' "$@"
  printf '$%s\n' "$shown"
  "$@"
}

# Extract the data files: invoke data-extractor.py with the work directory
# and the maximum number of data files. Expansions are quoted so a value
# containing spaces or glob characters is passed as a single argument.
echo '>> Extracting data'
run python data-extractor.py \
  --work-dir "$WORK_DIR" \
  --max-data-files "$MAX_DATA_FILES"
echo ''

# Preprocess the datasets: invoke preprocess.py on the work directory.
# The expansion is quoted so the path reaches the script as one argument.
echo '>> Preprocessing'
run python preprocess.py \
  --work-dir "$WORK_DIR"
echo ''

# Train and evaluate the model: invoke trainer/task.py on the work directory.
# The expansion is quoted so the path reaches the script as one argument.
echo '>> Training'
run python trainer/task.py \
  --work-dir "$WORK_DIR"
echo ''

# Get the model path: list the entries directly under the export directory
# and keep the one that sorts last (reverse sort, first line).
# NOTE(review): presumably the entries are timestamp-named exports so that
# "sorts last" means "most recent" — confirm against the trainer.
EXPORT_DIR="$WORK_DIR/model/export/final"
if [[ "$EXPORT_DIR" == gs://* ]]; then
  # The wildcard is quoted on purpose: gsutil expands it, not the shell.
  MODEL_DIR=$(gsutil ls -d "$EXPORT_DIR/*" | sort -r | head -n 1)
else
  MODEL_DIR=$(ls -d -1 "$EXPORT_DIR"/* | sort -r | head -n 1)
fi
# The pipeline's status is that of `head`, so a failed listing does not
# trip `set -e`; detect the empty result explicitly instead of continuing
# with no model path.
if [[ -z "$MODEL_DIR" ]]; then
  echo "error: no exported model found under $EXPORT_DIR" >&2
  exit 1
fi
echo "Model: $MODEL_DIR"
echo ''

# Make batch predictions on SDF files: invoke predict.py in "batch" mode,
# reading from <work-dir>/data and writing to <work-dir>/predictions.
# All expansions are quoted so each path is passed as exactly one argument
# (an unquoted empty or space-containing value would shift the arguments).
echo '>> Batch prediction'
run python predict.py \
  --work-dir "$WORK_DIR" \
  --model-dir "$MODEL_DIR" \
  batch \
  --inputs-dir "$WORK_DIR/data" \
  --outputs-dir "$WORK_DIR/predictions"

# Display some predictions: show the first lines of the prediction output.
# For a gs:// work directory the files are streamed through gsutil and the
# combined stream is cut to 10 lines; for a local directory `head` prints
# the first 10 lines of each file.
if [[ "$WORK_DIR" == gs://* ]]; then
  # The wildcard is quoted on purpose: gsutil expands it, not the shell.
  gsutil cat "$WORK_DIR/predictions/*" | head -n 10
else
  # Only the directory part is quoted so the shell still expands the glob.
  head -n 10 "$WORK_DIR"/predictions/*
fi
